[MLIR][XeGPU] Refactor xegpu-wg-to-sg tests #149204


Merged: 4 commits into llvm:main on Jul 30, 2025

Conversation

@nbpatel (Contributor) commented on Jul 16, 2025

This PR refactors the xegpu-wg-to-sg.mlir tests to use larger shapes that more closely resemble workgroup-level programming.
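
For context, a minimal before/after sketch of the conversion these tests exercise, adapted from the updated xegpu-wg-to-sg.mlir (the SSA names %off_y and %off_x are illustrative; the pass derives the actual offsets from gpu.subgroup_id through the affine maps checked in the test):

  // Workgroup level: one 256x128 descriptor, laid out over an 8x4 grid of subgroups.
  %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
    -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>

  // Subgroup level after xegpu-wg-to-sg: each subgroup creates one 32x32 descriptor
  // at its own offset, and the sg_layout/sg_data fields drop out of the layout.
  %sg_tdesc = xegpu.create_nd_tdesc %src[%off_y, %off_x] : memref<256x128xf32>
    -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [8, 4], lane_data = [1, 1]>>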

@nbpatel marked this pull request as ready for review on July 16, 2025 22:26
@nbpatel requested review from chencha3 and Jianhui-Li on July 16, 2025 22:26
@llvmbot (Member) commented on Jul 16, 2025

@llvm/pr-subscribers-mlir-gpu

@llvm/pr-subscribers-mlir

Author: Nishant Patel (nbpatel)

Changes

This PR refactors the xegpu-wg-to-sg.mlir tests to use larger shapes that more closely resemble workgroup-level programming.


Patch is 31.49 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/149204.diff

2 Files Affected:

  • (modified) mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir (+60-60)
  • (modified) mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir (+108-104)
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
index c6124f90e0f48..f1f00446366b3 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
@@ -2,104 +2,104 @@
 
 gpu.module @test_round_robin_assignment {
   // CHECK-LABEL: create_nd_tdesc
-  // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
-  gpu.func @create_nd_tdesc(%src: memref<24x32xf32>) {
-      // CHECK-COUNT-12: xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<24x32xf32>
-      // CHECK-SAME: -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+  gpu.func @create_nd_tdesc(%src: memref<256x128xf32>) {
+      // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<256x128xf32>
+      // CHECK-SAME: -> !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [8, 4], lane_data = [1, 1]>>
       // CHECK-NOT: xegpu.create_nd_tdesc
-      %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
-        -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+      %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+        -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [8, 4], lane_data = [1, 1]>>
       gpu.return
     }
 
   // CHECK-LABEL: load_nd_tdesc
-  // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
-  gpu.func @load_nd_tdesc(%src: memref<24x32xf32>) {
-      %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
-        -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
-      // CHECK-COUNT-12: xegpu.load_nd %{{.*}}
-      // CHECK-SAME-COUNT-12: : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
-      // CHECK-SAME-COUNT-12: -> vector<2x2xf32>
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+  gpu.func @load_nd_tdesc(%src: memref<256x128xf32>) {
+      %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+        -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [8, 4], lane_data = [1, 1]>>
+      // CHECK-COUNT-4: xegpu.load_nd %{{.*}}
+      // CHECK-SAME-COUNT-4: : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [8, 4], lane_data = [1, 1]>>
+      // CHECK-SAME-COUNT-4: -> vector<16x16xf32>
       // CHECK-NOT: xegpu.load_nd
       %load =  xegpu.load_nd %tdesc
-        : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
-        -> vector<24x32xf32>
+        : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [8, 4], lane_data = [1, 1]>>
+        -> vector<256x128xf32>
       gpu.return
     }
 
   // CHECK-LABEL: store_nd
-  // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
-  gpu.func @store_nd(%src: memref<24x32xf32>) {
-      %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
-        -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
-      // CHECK-COUNT-12: xegpu.store_nd %{{.*}}, %{{.*}}
-      // CHECK-SAME-COUNT-12: : vector<2x2xf32>, !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+  gpu.func @store_nd(%src: memref<256x128xf32>) {
+      %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+        -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [8, 4], lane_data = [1, 1]>>
+      // CHECK-COUNT-4: xegpu.store_nd %{{.*}}, %{{.*}}
+      // CHECK-SAME-COUNT-4: : vector<16x16xf32>, !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [8, 4], lane_data = [1, 1]>>
       // CHECK-NOT : xegpu.store_nd
       %load = xegpu.load_nd %tdesc
-        : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
-        -> vector<24x32xf32>
+        : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [8, 4], lane_data = [1, 1]>>
+        -> vector<256x128xf32>
       xegpu.store_nd %load, %tdesc
-        : vector<24x32xf32>, !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+        : vector<256x128xf32>, !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [8, 4], lane_data = [1, 1]>>
       gpu.return
   }
 
   // CHECK-LABEL: update_nd
-  // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
-  gpu.func @update_nd(%src: memref<24x32xf32>){
-    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
-      ->  !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
-    // CHECK-COUNT-12: xegpu.update_nd_offset %{{.*}}, [0, 16]
-    // CHECK-SAME-COUNT-12: : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+  gpu.func @update_nd(%src: memref<256x128xf32>){
+    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+      ->  !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [8, 4], lane_data = [1, 1]>>
+    // CHECK-COUNT-4: xegpu.update_nd_offset %{{.*}}, [0, 16]
+    // CHECK-SAME-COUNT-4: : !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [8, 4], lane_data = [1, 1]>>>
     // CHECK-NOT: xegpu.update_nd_offset
     %update = xegpu.update_nd_offset %tdesc, [0, 16]
-      : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+      : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [8, 4], lane_data = [1, 1]>>
     gpu.return
   }
 
   // CHECK-LABEL: dpas
-  // CHECK-SAME: (%[[ARG_0:.*]]: memref<8x8xf32>, %[[ARG_1:.*]]: memref<8x8xf32>, %[[ARG_2:.*]]: memref<8x8xf32>)
-  gpu.func @dpas(%a: memref<8x8xf32>, %b: memref<8x8xf32>, %c: memref<8x8xf32>) {
-    // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<8x8xf32>
-    // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+  // CHECK-SAME: (%[[ARG_0:.*]]: memref<256x128xf32>, %[[ARG_1:.*]]: memref<128x256xf32>, %[[ARG_2:.*]]: memref<256x256xf32>)
+  gpu.func @dpas(%a: memref<256x128xf32>, %b: memref<128x256xf32>, %c: memref<256x256xf32>) {
+    // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<256x128xf32>
+    // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [8, 4], lane_data = [1, 1]>>
     // CHECK-NOT: xegpu.create_nd_tdesc
-    // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_1]][%{{.*}}, %{{.*}}] : memref<8x8xf32>
-    // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+    // CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_1]][%{{.*}}, %{{.*}}] : memref<128x256xf32>
+    // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [4, 8], lane_data = [1, 1]>>
     // CHECK-NOT: xegpu.create_nd_tdesc
-    // CHECK-COUNT-4:  xegpu.create_nd_tdesc %{{.*}}[%{{.*}}, %{{.*}}] : memref<8x8xf32>
-    // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+    // CHECK-COUNT-4:  xegpu.create_nd_tdesc %{{.*}}[%{{.*}}, %{{.*}}] : memref<256x256xf32>
+    // CHECK-SAME-COUNT-4: -> !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [8, 8], lane_data = [1, 1]>>
     // CHECK-NOT: xegpu.create_nd_tdesc
     // CHECK-COUNT-16: xegpu.dpas %{{.*}}, %{{.*}}
-    // CHECK-SAME-COUNT-16: {layout = #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>}
-    // CHECK-SAME-COUNT-16: : vector<2x2xf32>, vector<2x2xf32> -> vector<2x2xf32>
+    // CHECK-SAME-COUNT-16: {layout = #xegpu.layout<lane_layout = [8, 8], lane_data = [1, 1]>}
+    // CHECK-SAME-COUNT-16: : vector<16x16xf32>, vector<16x16xf32> -> vector<16x16xf32>
     // CHECK-NOT: xegpu.dpas
-    %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<8x8xf32>
-      -> !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+    %tdesc_a = xegpu.create_nd_tdesc %a[0, 0] : memref<256x128xf32>
+      -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [8, 4], lane_data = [1, 1]>>
     %load_a =  xegpu.load_nd %tdesc_a
-      : !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
-      -> vector<8x8xf32>
-    %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<8x8xf32>
-      -> !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+      : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [8, 4], lane_data = [1, 1]>>
+      -> vector<256x128xf32>
+    %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<128x256xf32>
+      -> !xegpu.tensor_desc<128x256xf32, #xegpu.layout<sg_layout = [4, 8], sg_data = [16, 16], lane_layout = [4, 8], lane_data = [1, 1]>>
     %load_b =  xegpu.load_nd %tdesc_b
-      : !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
-      -> vector<8x8xf32>
-    %tdesc_c = xegpu.create_nd_tdesc %c[0, 0] : memref<8x8xf32>
-      -> !xegpu.tensor_desc<8x8xf32, #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+      : !xegpu.tensor_desc<128x256xf32, #xegpu.layout<sg_layout = [4, 8], sg_data = [16, 16], lane_layout = [4, 8], lane_data = [1, 1]>>
+      -> vector<128x256xf32>
+    %tdesc_c = xegpu.create_nd_tdesc %c[0, 0] : memref<256x256xf32>
+      -> !xegpu.tensor_desc<256x256xf32, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16], lane_layout = [8, 8], lane_data = [1, 1]>>
     %dpas = xegpu.dpas %load_a, %load_b
-      {layout_result_0 =  #xegpu.layout<sg_layout = [2, 2], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>}
-      : vector<8x8xf32>, vector<8x8xf32> -> vector<8x8xf32>
+      {layout_result_0 =  #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16], lane_layout = [8, 8], lane_data = [1, 1]>}
+      : vector<256x128xf32>, vector<128x256xf32> -> vector<256x256xf32>
     gpu.return
   }
 
   // CHECK-LABEL: prefetch_nd_tdesc
-  // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
-  gpu.func @prefetch_nd_tdesc(%src: memref<24x32xf32>) {
-    // CHECK-COUNT-12: xegpu.prefetch_nd %{{.*}}
-    // CHECK-SAME-COUNT-12 : !xegpu.tensor_desc<2x2xf32, #xegpu.layout<lane_layout = [2, 2], lane_data = [1, 1]>>
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+  gpu.func @prefetch_nd_tdesc(%src: memref<256x128xf32>) {
+    // CHECK-COUNT-4: xegpu.prefetch_nd %{{.*}}
+    // CHECK-SAME-COUNT-4: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<lane_layout = [8, 4], lane_data = [1, 1]>>
     // CHECK-NOT: xegpu.prefetch_nd
-    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
-      -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+      -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [8, 4], lane_data = [1, 1]>>
     xegpu.prefetch_nd %tdesc
-      : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [2, 2], lane_layout = [2, 2], lane_data = [1, 1]>>
+      : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [8, 4], lane_data = [1, 1]>>
     gpu.return
   }
 
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
index 44b11c304cc80..2ae97a42cfdd4 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -4,169 +4,173 @@
 //CHECK: #map1 = affine_map<()[s0] -> (s0 mod 4)>
 gpu.module @test_1_1_assignment {
   // CHECK-LABEL: create_nd_tdesc
-  // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
-  gpu.func @create_nd_tdesc(%src: memref<24x32xf32>) {
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+  gpu.func @create_nd_tdesc(%src: memref<256x128xf32>) {
   // CHECK: %[[SGID:.*]] = gpu.subgroup_id
-  // CHECK: %[[C12:.*]] = arith.constant 12 : index
-  // CHECK: %[[C4:.*]] = arith.constant 4 : index
   // CHECK: %[[C8:.*]] = arith.constant 8 : index
+  // CHECK: %[[C32:.*]] = arith.constant 32 : index
+  // CHECK: %[[C4:.*]] = arith.constant 4 : index
+  // CHECK: %[[C32_0:.*]] = arith.constant 32 : index
+  // CHECK: %[[C4_1:.*]] = arith.constant 4 : index
   // CHECK: %[[DIV:.*]] = affine.apply #map()[%[[SGID]]]
   // CHECK: %[[REM:.*]] = affine.apply #map1()[%[[SGID]]]
-  // CHECK: %[[MUL1:.*]] = index.mul %[[DIV]], %[[C12]]
-  // CHECK: %[[MUL2:.*]] = index.mul %[[REM]], %[[C8]]
-  // CHECK: %[[C24:.*]] = arith.constant 24 : index
-  // CHECK: %[[MOD:.*]] = index.remu %[[MUL1]], %[[C24]]
+  // CHECK: %[[MUL1:.*]] = index.mul %[[DIV]], %[[C32]]
+  // CHECK: %[[MUL2:.*]] = index.mul %[[REM]], %[[C32_0]]
   // CHECK: %[[C0:.*]] = arith.constant 0 : index
-  // CHECK: %[[ADD1:.*]] = index.add %[[MOD]], %[[C0]]
-  // CHECK: %[[C32:.*]] = arith.constant 32 : index
-  // CHECK: %[[MOD1:.*]] = index.remu %[[MUL2]], %[[C32]]
-  // CHECK: %[[C0_1:.*]] = arith.constant 0 : index
-  // CHECK: %[[ADD2:.*]] = index.add %[[MOD1]], %[[C0_1]]
-  // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][%[[ADD1]], %[[ADD2]]] : memref<24x32xf32>
-  // CHECK-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+  // CHECK: %[[C256:.*]] = arith.constant 256 : index
+  // CHECK: %[[MOD:.*]] = index.remu %[[MUL1]], %[[C256]]
+  // CHECK: %[[C0_2:.*]] = arith.constant 0 : index
+  // CHECK: %[[ADD1:.*]] = index.add %[[MOD]], %[[C0_2]]
+  // CHECK: %[[C0_3:.*]] = arith.constant 0 : index
+  // CHECK: %[[C128:.*]] = arith.constant 128 : index
+  // CHECK: %[[MOD1:.*]] = index.remu %[[MUL2]], %[[C128]]
+  // CHECK: %[[C0_4:.*]] = arith.constant 0 : index
+  // CHECK: %[[ADD2:.*]] = index.add %[[MOD1]], %[[C0_4]]
+  // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][%[[ADD1]], %[[ADD2]]] : memref<256x128xf32>
+  // CHECK-SAME: -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [8, 4], lane_data = [1, 1]>>
   // CHECK: gpu.return
-  %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
-    -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+  %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+    -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
   gpu.return
   }
 
   // CHECK-LABEL: load_nd_tdesc
-  // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
-  gpu.func @load_nd_tdesc(%src: memref<24x32xf32>) {
-    // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32>
-    // CHECK-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+  gpu.func @load_nd_tdesc(%src: memref<256x128xf32>) {
+    // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<256x128xf32>
+    // CHECK-SAME: -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [8, 4], lane_data = [1, 1]>>
     // CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC]]
-    // CHECK-SAME: : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
-    // CHECK-SAME: -> vector<12x8xf32>
-    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
-      -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+    // CHECK-SAME: : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [8, 4], lane_data = [1, 1]>>
+    // CHECK-SAME: -> vector<32x32xf32>
+    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+      -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
     %load =  xegpu.load_nd %tdesc
-      : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
-      -> vector<24x32xf32>
+      : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
+      -> vector<256x128xf32>
     gpu.return
   }
 
   // CHECK-LABEL: store_nd
-  // CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
-  gpu.func @store_nd(%src: memref<24x32xf32>) {
-    // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32>
-    // CHECK-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+  gpu.func @store_nd(%src: memref<256x128xf32>) {
+    // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<256x128xf32>
+    // CHECK-SAME: -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [8, 4], lane_data = [1, 1]>>
     // CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC]]
-    // CHECK-SAME: : !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
-    // CHECK-SAME: -> vector<12x8xf32>
+    // CHECK-SAME: : !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [8, 4], lane_data = [1, 1]>>
+    // CHECK-SAME: -> vector<32x32xf32>
     // CHECK: xegpu.store_nd %[[LOAD]], %[[TDESC]]
-    // CHECK-SAME: : vector<12x8xf32>, !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
-    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32>
-      -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+    // CHECK-SAME: : vector<32x32xf32>, !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [8, 4], lane_data = [1, 1]>>
+    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+      -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
     %load = xegpu.load_nd %tdesc
-      : !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
-      -> vector<24x32xf32>
+      : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
+      -> vector<256x128xf32>
     xegpu.store_nd %load, %tdesc
-      : vector<24x32xf32>, !xegpu.tensor_desc<24x32xf32, #xegpu.layout<sg_layout = [2, 4], sg_data = [12, 8], lane_layout = [2, 8], lane_data = [1, 1]>>
+      : vector<256x128xf32>, !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
     gpu.return
 }
 
 // CHECK-LABEL: update_nd
-// CHECK-SAME: %[[ARG_0:.*]]: memref<24x32xf32>
-gpu.func @update_nd(%src: memref<24x32xf32>){
-  // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<24x32xf32>
-  // CHECK-SAME: -> !xegpu.tensor_desc<12x8xf32, #xegpu.layout<lane_layout = [2, 8], lane_data = [1, 1]>>
+// CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
+gpu.func @update_nd(%src: memref<256x128xf32>){
+  // CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<256x128xf32>
+  // CHECK-SAME: -> !xegpu.tensor_desc<32x32xf32, #xegpu.layout<lane_layout = [8, 4], lane_data = [1, 1]>>
   // CHECK: %[[UPDATE:.*]] = xegpu.update_nd_offset %[[TDESC]], [0, 16]
-  // CHECK-SAME: : !xegpu.tensor_de...
[truncated]
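
As a sanity check on the new CHECK-COUNT values in the round-robin test: a 256x128 tensor cut into 16x16 subgroup tiles yields (256/16) x (128/16) = 16 x 8 = 128 tiles, distributed round-robin over an 8x4 = 32-subgroup grid, i.e. 128 / 32 = 4 tiles per subgroup, hence CHECK-COUNT-4. The old 24x32 shape with 2x2 tiles and a 4x4 grid gave 192 / 16 = 12, hence the previous CHECK-COUNT-12.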

// CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
gpu.func @create_nd_tdesc(%src: memref<256x128xf32>) {
// CHECK-COUNT-4: xegpu.create_nd_tdesc %[[ARG_0]][%{{.*}}, %{{.*}}] : memref<256x128xf32>
// CHECK-SAME: -> !xegpu.tensor_desc<16x16xf32, #xegpu.layout<lane_layout = [8, 4], lane_data = [1, 1]>>

Contributor:
lane_layout = [8, 4] => lane_layout = [1, 16]
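
Sketched on the descriptor in question, the suggestion would change only the lane fields (hypothetical rewrite, not part of the patch):

  -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>>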

: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [8, 4], lane_data = [1, 1]>>
-> vector<256x128xf32>
%tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<128x256xf32>
-> !xegpu.tensor_desc<128x256xf32, #xegpu.layout<sg_layout = [4, 8], sg_data = [16, 16], lane_layout = [4, 8], lane_data = [1, 1]>>

Contributor:
lane_layout = [4, 8], lane_data = [1, 1] => lane_layout = [1, 16], lane_data = [2, 1] for bf16
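
Applied to the B operand, that would read roughly as follows (hypothetical sketch; the element type switches to bf16 and each lane packs two values along the reduction dimension):

  %tdesc_b = xegpu.create_nd_tdesc %b[0, 0] : memref<128x256xbf16>
    -> !xegpu.tensor_desc<128x256xbf16, #xegpu.layout<sg_layout = [4, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [2, 1]>>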

: !xegpu.tensor_desc<128x256xf32, #xegpu.layout<sg_layout = [4, 8], sg_data = [16, 16], lane_layout = [4, 8], lane_data = [1, 1]>>
-> vector<128x256xf32>
%tdesc_c = xegpu.create_nd_tdesc %c[0, 0] : memref<256x256xf32>
-> !xegpu.tensor_desc<256x256xf32, #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16], lane_layout = [8, 8], lane_data = [1, 1]>>

Contributor:
dpas does not support f32.


Contributor:
On that note, the op should be restricted in the future to reject unsupported types
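
Per the two comments above, a variant of the dpas test with hardware-supported element types might look like this (hypothetical sketch: bf16 inputs accumulating into f32, lane fields per the earlier suggestions):

  %dpas = xegpu.dpas %load_a, %load_b
    {layout_result_0 = #xegpu.layout<sg_layout = [8, 8], sg_data = [16, 16], lane_layout = [1, 16], lane_data = [1, 1]>}
    : vector<256x128xbf16>, vector<128x256xbf16> -> vector<256x256xf32>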

// CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
gpu.func @load_nd_tdesc(%src: memref<256x128xf32>) {
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
-> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [8, 4], lane_data = [1, 1]>>

Contributor:
consider adding order=[1, 0]
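
Sketched on this descriptor (hypothetical; assumes the optional order field of #xegpu.layout, with [1, 0] denoting row-major):

  -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16], lane_layout = [8, 4], lane_data = [1, 1], order = [1, 0]>>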

@nbpatel force-pushed the xegpu-refactor-tests branch from 53f1e5d to 82b951e on July 25, 2025 16:20
@chencha3 (Contributor) left a comment:

LGTM

@nbpatel (Contributor, Author) commented on Jul 30, 2025

@Jianhui-Li any more comments? I addressed all the feedback.

@Jianhui-Li (Contributor) left a comment:

LGTM

@nbpatel merged commit 8070e9b into llvm:main on Jul 30, 2025
9 checks passed